{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess   = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q smdebug==0.7.2\n",
    "!pip install -q sagemaker-experiments==0.1.11"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify the S3 Location of the Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r scikit_processing_job_s3_output_prefix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(scikit_processing_job_s3_output_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HACK:  Remove this once we test the new `05_prepare/` and remove the `is_real_example`\n",
    "\n",
    "scikit_processing_job_s3_output_prefix = 'sagemaker-scikit-learn-2020-04-10-04-44-09-647'\n",
    "\n",
    "print('Previous Scikit Processing Job Name: {}'.format(scikit_processing_job_s3_output_prefix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix_train = '{}/output/bert-train'.format(scikit_processing_job_s3_output_prefix)\n",
    "prefix_validation = '{}/output/bert-validation'.format(scikit_processing_job_s3_output_prefix)\n",
    "prefix_test = '{}/output/bert-test'.format(scikit_processing_job_s3_output_prefix)\n",
    "\n",
    "train_s3_uri = 's3://{}/{}'.format(bucket, prefix_train)\n",
    "validation_s3_uri = 's3://{}/{}'.format(bucket, prefix_validation)\n",
    "test_s3_uri = 's3://{}/{}'.format(bucket, prefix_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(train_s3_uri)\n",
    "!aws s3 ls $train_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_input_train_data = sagemaker.s3_input(s3_data=train_s3_uri, distribution='ShardedByS3Key') \n",
    "s3_input_validation_data = sagemaker.s3_input(s3_data=validation_s3_uri, distribution='ShardedByS3Key')\n",
    "s3_input_test_data = sagemaker.s3_input(s3_data=test_s3_uri, distribution='ShardedByS3Key')\n",
    "\n",
    "print(s3_input_train_data.config)\n",
    "print(s3_input_validation_data.config)\n",
    "print(s3_input_test_data.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!cat src/tf_bert_reviews.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Debugger Rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.debugger import Rule, CollectionConfig, rule_configs\n",
    "\n",
    "model_output_path = 's3://{}/models/tf2-bert'.format(bucket)\n",
    "\n",
    "rules=[\n",
    "        Rule.sagemaker(\n",
    "            rule_configs.loss_not_decreasing(),\n",
    "            rule_parameters={\n",
    "                'collection_names': 'losses,metrics',\n",
    "                'use_losses_collection': 'true',\n",
    "                'num_steps': '5',\n",
    "                'diff_percent': '5'\n",
    "            },\n",
    "            collections_to_save=[\n",
    "                CollectionConfig(name='losses',\n",
    "                                 parameters={\n",
    "                                     'save_interval': '100',\n",
    "                                 }),\n",
    "                CollectionConfig(name='metrics',\n",
    "                                 parameters={\n",
    "                                     'save_interval': '100',\n",
    "                                 })\n",
    "            ]\n",
    "        ),\n",
    "        Rule.sagemaker(\n",
    "            rule_configs.overtraining(),\n",
    "            rule_parameters={\n",
    "                'collection_names': 'losses,metrics',\n",
    "                'patience_train': '5',\n",
    "                'patience_validation': '10',\n",
    "                'delta': '0.1'\n",
    "            },\n",
    "            collections_to_save=[\n",
    "                CollectionConfig(name='losses',\n",
    "                                 parameters={\n",
    "                                     'save_interval': '100',\n",
    "                                 }),\n",
    "                CollectionConfig(name='metrics',\n",
    "                                 parameters={\n",
    "                                     'save_interval': '100',\n",
    "                                 })\n",
    "            ]\n",
    "        )\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Hyper-Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs=2\n",
    "train_batch_size=128\n",
    "validation_batch_size=128\n",
    "test_batch_size=128\n",
    "train_steps_per_epoch=100\n",
    "validation_steps=100\n",
    "test_steps=100\n",
    "train_instance_count=2\n",
    "train_instance_type='ml.p3.2xlarge'\n",
    "train_volume_size=1800\n",
    "use_xla=True\n",
    "use_amp=True\n",
    "max_seq_length=128\n",
    "freeze_bert_layer=True\n",
    "enable_sagemaker_debugger=False                   # Enable SM Debugger\n",
    "#use_parameter_server=(train_instance_count >= 2) # Use Parameter Server if distributed cluster\n",
    "input_mode='Pipe'                                 # 'File' or 'Pipe' Mode\n",
    "run_validation=True\n",
    "run_test=True\n",
    "run_sample_predictions=True\n",
    "#disable_eager_execution=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Track Experiments\n",
    "import time\n",
    "unique_id = '{}-{}'.format(input_mode, int(time.time()))\n",
    "\n",
    "from smexperiments.experiment import Experiment\n",
    "experiment=Experiment.create(\n",
    "    experiment_name='train-reviews-bert-{}'.format(unique_id),\n",
    "    description='Train Reviews BERT', \n",
    "    sagemaker_boto_client=sm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from smexperiments.tracker import Tracker\n",
    "tracker_display_name='Train-Reviews-BERT-Tracker-{}'.format(unique_id)\n",
    "print(tracker_display_name)\n",
    "\n",
    "tracker = Tracker.create(display_name=tracker_display_name, sagemaker_boto_client=sm)\n",
    "tracker.log_parameters({\n",
    "    'epochs': epochs,\n",
    "    'train_batch_size': train_batch_size,\n",
    "    'validation_batch_size': validation_batch_size,\n",
    "    'test_batch_size': test_batch_size,\n",
    "    'train_steps_per_epoch': train_steps_per_epoch,\n",
    "    'validation_steps': validation_steps,\n",
    "    'test_steps': test_steps,\n",
    "    'train_instance_count': train_instance_count,\n",
    "    'train_instance_type': train_instance_type,\n",
    "    'train_volume_size': train_volume_size,\n",
    "    'use_xla': use_xla,\n",
    "    'use_amp': use_amp,\n",
    "    'max_seq_length': max_seq_length,\n",
    "    'freeze_bert_layer': freeze_bert_layer,\n",
    "    'enable_sagemaker_debugger': enable_sagemaker_debugger,\n",
    "#    'use_parameter_server': use_parameter_server,\n",
    "    'input_mode': input_mode, # 'File' or 'Pipe'\n",
    "    'run_validation': run_validation,\n",
    "    'run_test': run_test,\n",
    "    'run_sample_predictions': run_sample_predictions,    \n",
    "#    'disable_eager_execution': disable_eager_execution,\n",
    "})\n",
    "# we can log the s3 uri to the dataset we just uploaded\n",
    "tracker.log_input(name='reviews_dataset_train', media_type='s3/uri', value=train_s3_uri)\n",
    "tracker.log_input(name='reviews_dataset_validation', media_type='s3/uri', value=validation_s3_uri)\n",
    "tracker.log_input(name='reviews_dataset_test', media_type='s3/uri', value=test_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from smexperiments.trial import Trial\n",
    "trial_name='train-reviews-bert-training-job-{}'.format(unique_id)\n",
    "trial = Trial.create(trial_name=trial_name, experiment_name=experiment.experiment_name, sagemaker_boto_client=sm)\n",
    "trial.add_trial_component(tracker.trial_component)\n",
    "trial_component_display_name='Train-Reviews-BERT-Trial-{}'.format(unique_id)\n",
    "experiment_config={'ExperimentName': experiment.experiment_name,\n",
    "                   'TrialName': trial.trial_name,\n",
    "                   'TrialComponentDisplayName': trial_component_display_name}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "#tensorboard_output_config = TensorBoardOutputConfig('s3://smdebug-dev-demo-pdx/mnist/tensorboard')\n",
    "\n",
    "metrics_definitions = [\n",
    "     {'Name': 'train:loss', 'Regex': 'loss: ([0-9\\\\.]+)'},\n",
    "     {'Name': 'train:accuracy', 'Regex': 'accuracy: ([0-9\\\\.]+)'},\n",
    "     {'Name': 'validation:loss', 'Regex': 'val_loss: ([0-9\\\\.]+)'},\n",
    "     {'Name': 'validation:accuracy', 'Regex': 'val_accuracy: ([0-9\\\\.]+)'},\n",
    "]\n",
    "\n",
    "estimator = TensorFlow(entry_point='tf_bert_reviews.py',\n",
    "                            source_dir='src',\n",
    "                            role=role,\n",
    "                            train_instance_count=train_instance_count, # Make sure you have at least this number of input files or the ShardedByS3Key distibution strategy will fail the job due to no data available\n",
    "                            train_instance_type=train_instance_type,\n",
    "                            train_volume_size=train_volume_size,\n",
    "                            py_version='py3',\n",
    "                            framework_version='2.1.0',\n",
    "                            output_path=model_output_path,\n",
    "                            hyperparameters={'epochs': epochs,\n",
    "                                             'train-batch-size': train_batch_size,\n",
    "                                             'validation-batch-size': validation_batch_size,\n",
    "                                             'test-batch-size': test_batch_size,                                             \n",
    "                                             'train-steps-per-epoch': train_steps_per_epoch,\n",
    "                                             'validation-steps': validation_steps,\n",
    "                                             'test-steps': test_steps,\n",
    "                                             'use-xla': use_xla,\n",
    "                                             'use-amp': use_amp,                                             \n",
    "                                             'max-seq-length': max_seq_length,\n",
    "                                             'freeze-bert-layer': freeze_bert_layer,\n",
    "                                             'enable-sagemaker-debugger': enable_sagemaker_debugger,\n",
    "                                             'run-validation': run_validation,\n",
    "                                             'run-test': run_test,\n",
    "                                             'run-sample-predictions': run_sample_predictions},\n",
    "#                                             'disable-eager-execution': disable_eager_execution},\n",
    "#                            distributions={'parameter_server': {'enabled': use_parameter_server}},\n",
    "                            input_mode=input_mode,\n",
    "#                            enable_cloudwatch_metrics=True,\n",
    "                            metric_definitions=metrics_definitions,\n",
    "                            rules=rules,\n",
    "#                            tensorboard_output_config=tensorboard_output_config\n",
    "                            train_max_run=7200 # max 2 hours * 60 minutes seconds per hour * 60 seconds per minute\n",
    "                           )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.fit(inputs={'train': s3_input_train_data, \n",
    "                      'validation': s3_input_validation_data,\n",
    "                      'test': s3_input_test_data\n",
    "                     },\n",
    "                     experiment_config=experiment_config,                   \n",
    "                     wait=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_job_name = estimator.latest_training_job.name\n",
    "print('training_job_name:  {}'.format(training_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}\">Training Job</a> After About 5 Minutes</b>'.format(region, training_job_name)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>'.format(region, training_job_name)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "training_job_s3_output_prefix = 'models/tf2-bert/{}'.format(training_job_name) # 'models/tf-bert/script-mode/training-runs/{}'.format(training_job_name)\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>'.format(bucket, training_job_s3_output_prefix, region)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Track Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.analytics import ExperimentAnalytics\n",
    "\n",
    "# # Might need to convert ' => \"\n",
    "# search_expression = {\n",
    "#     'Filters':[\n",
    "#         {\n",
    "#             'Name': 'DisplayName',\n",
    "#             'Operator': 'Equals',\n",
    "#             'Value': 'Training'\n",
    "#         }\n",
    "#     ]\n",
    "# }\n",
    "\n",
    "trial_component_analytics = ExperimentAnalytics(\n",
    "    sagemaker_session=sess, \n",
    "    experiment_name=experiment.experiment_name,\n",
    "#    search_expression=search_expression,\n",
    "    sort_by='metrics.validation:accuracy.max',\n",
    "    sort_order='Descending',\n",
    "    metric_names=['validation:accuracy'],\n",
    "    parameter_names=['epochs', 'train_batch_size']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "analytics_table = trial_component_analytics.dataframe()\n",
    "analytics_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# search_expression={\n",
    "#     'Filters':[{\n",
    "#         'Name': 'Parents.TrialName',\n",
    "#         'Operator': 'Equals',\n",
    "#         'Value': ??\n",
    "#     }]\n",
    "# },\n",
    "\n",
    "lineage_table = ExperimentAnalytics(\n",
    "    sagemaker_session=sess,\n",
    "    experiment_name=experiment.experiment_name,\n",
    "#    search_expression=search_expression,\n",
    "    sort_by=\"CreationTime\",\n",
    "    sort_order=\"Ascending\",\n",
    ")\n",
    "lineage_table.dataframe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyze Debugger Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# estimator.latest_training_job.rule_job_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# from smdebug.trials import create_trial\n",
    "\n",
    "# # this is where we create a Trial object that allows access to saved tensors\n",
    "# trial = create_trial(estimator.latest_job_debugger_artifacts_path())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Lambda Function to Stop the TrainingJob Early\n",
    "In your AWS console, go to Lambda Management Console,\n",
    "Create a new function by hitting Create Function,\n",
    "Choose the language as Python 3.7 and put in the following sample code for stopping the training job if one of the Rule statuses is \"IssuesFound\".\n",
    "\n",
    "### Cloudwatch Events for Rules\n",
    "Rule status changes in a training job trigger CloudWatch Events. These events can be acted upon by configuring a CloudWatch Rule (different from Amazon SageMaker Debugger Rule) to trigger each time a Debugger Rule changes status. In this notebook we'll go through how you can create a CloudWatch Rule to direct Training Job State change events to a lambda function that stops the training job in case a rule triggers and has status \"IssuesFound\".\n",
    "\n",
    "Create a new execution role for the Lambda, and\n",
    "In your IAM console, search for the role and attach \"AmazonSageMakerFullAccess\" policy to the role. This is needed for the code in your Lambda function to stop the training job.\n",
    "\n",
    "\n",
    "### Create a CloudWatch Rule to Trigger a Lamba\n",
    "In your AWS Console, go to CloudWatch and select Rule from the left column,\n",
    "Hit Create Rule. The console will redirect you to the Rule creation page,\n",
    "For the Service Name, select \"SageMaker\".\n",
    "For the Event Type, select \"SageMaker Training Job State Change\".\n",
    "In the Targets select the Lambda function you created above, and\n",
    "For this example notebook, we'll leave everything as is.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create the Lambda"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "import json\n",
    "import boto3\n",
    "import logging\n",
    "\n",
    "def lambda_handler(event, context):\n",
    "    training_job_name = event.get(\"detail\").get(\"TrainingJobName\")\n",
    "    eval_statuses = event.get(\"detail\").get(\"DebugRuleEvaluationStatuses\", None)\n",
    "\n",
    "    if eval_statuses is None or len(eval_statuses) == 0:\n",
    "        logging.info(\"Couldn't find any debug rule statuses, skipping...\")\n",
    "        return {\n",
    "            'statusCode': 200,\n",
    "            'body': json.dumps('Nothing to do')\n",
    "        }\n",
    "\n",
    "    client = boto3.client('sagemaker')\n",
    "\n",
    "    for status in eval_statuses:\n",
    "        if status.get(\"RuleEvaluationStatus\") == \"IssuesFound\":\n",
    "            logging.info(\n",
    "                'Evaluation of rule configuration {} resulted in \"IssuesFound\". '\n",
    "                'Attempting to stop training job {}'.format(\n",
    "                    status.get(\"RuleConfigurationName\"), training_job_name\n",
    "                )\n",
    "            )\n",
    "            try:\n",
    "                client.stop_training_job(\n",
    "                    TrainingJobName=training_job_name\n",
    "                )\n",
    "            except Exception as e:\n",
    "                logging.error(\n",
    "                    \"Encountered error while trying to \"\n",
    "                    \"stop training job {}: {}\".format(\n",
    "                        training_job_name, str(e)\n",
    "                    )\n",
    "                )\n",
    "                raise e\n",
    "    return None\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Create a new execution role for the Lambda, and\n",
    "* In your IAM console, search for the role and attach \"AmazonSageMakerFullAccess\" policy to the role. This is needed for the code in your Lambda function to stop the training job.\n",
    "\n",
    "#### Create a CloudWatch Rule\n",
    "\n",
    "* In your AWS Console, go to CloudWatch and select Rule from the left column,\n",
    "* Hit Create Rule. The console will redirect you to the Rule creation page,\n",
    " * For the Service Name, select \"SageMaker\".\n",
    " * For the Event Type, select \"SageMaker Training Job State Change\".\n",
    "* In the Targets select the Lambda function you created above, and\n",
    "* For this example notebook, we'll leave everything as is."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SageMaker kicked off rule evaluation jobs, one for each of the SageMaker rules - `Overtraining` and `LossNotDecreasing` as specified in the estimator. If we setup a CloudWatch Rule to stop the training job, we would see the `TrainingJobStatus` change to `Stopped` once the `RuleEvaluationStatus` for changes to `IssuesFound`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # This utility gives the link to monitor the CW event\n",
    "# def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):\n",
    "#         \"\"\"Helper function to get the rule job name\"\"\"\n",
    "#         return \"{}-{}-{}\".format(\n",
    "#             training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]\n",
    "#         )\n",
    "    \n",
    "# def _get_cw_url_for_rule_job(rule_job_name, region):\n",
    "#     return \"https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix\".format(region, region, rule_job_name)\n",
    "\n",
    "\n",
    "# def get_rule_jobs_cw_urls(estimator):\n",
    "#     region = boto3.Session().region_name\n",
    "#     training_job = estimator.latest_training_job\n",
    "#     training_job_name = training_job.describe()[\"TrainingJobName\"]\n",
    "#     rule_eval_statuses = training_job.describe()[\"DebugRuleEvaluationStatuses\"]\n",
    "    \n",
    "#     result={}\n",
    "#     for status in rule_eval_statuses:\n",
    "#         if status.get(\"RuleEvaluationJobArn\", None) is not None:\n",
    "#             rule_job_name = _get_rule_job_name(training_job_name, status[\"RuleConfigurationName\"], status[\"RuleEvaluationJobArn\"])\n",
    "#             result[status[\"RuleConfigurationName\"]] = _get_cw_url_for_rule_job(rule_job_name, region)\n",
    "#     return result\n",
    "\n",
    "# get_rule_jobs_cw_urls(estimator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimator.latest_training_job.describe()[\"TrainingJobStatus\"]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Hyper-Parameter Ranges to Explore\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tuner import IntegerParameter\n",
    "from sagemaker.tuner import ContinuousParameter\n",
    "from sagemaker.tuner import CategoricalParameter\n",
    "from sagemaker.tuner import HyperparameterTuner\n",
    "                                                \n",
    "hyperparameter_ranges = {\n",
    "    'use-xla': CategoricalParameter([True, False]),\n",
    "    'use-amp': CategoricalParameter([True, False]),\n",
    "#    'train-batch-size': CategoricalParameter([128]),\n",
    "#    'validation-batch-size': CategoricalParameter([128]),\n",
    "#    'test-batch-size': CategoricalParameter([128]),\n",
    "    'epochs': IntegerParameter(2, 16, scaling_type='Logarithmic'),\n",
    "    'train-steps-per-epoch': IntegerParameter(10, 1000, scaling_type='Logarithmic'),\n",
    "#    'validation_steps': CategoricalParameter([100]),\n",
    "#    'test-steps': CategoricalParameter([100]),\n",
    "#    'max-seq-length': CategoricalParameter([128]),\n",
    "    'freeze-bert-layer': CategoricalParameter([True, False]),\n",
    "#    'disable-eager-execution': CategoricalParameter([True, False])\n",
    "#    'enabled-sagemaker-debugger': CategoricalParameter([True])\n",
    "}\n",
    "\n",
    "objective_metric_name = 'validation:accuracy'\n",
    "\n",
    "tuner = HyperparameterTuner(\n",
    "    estimator=estimator,\n",
    "    objective_metric_name=objective_metric_name,\n",
    "    hyperparameter_ranges=hyperparameter_ranges,\n",
    "    metric_definitions=metrics_definitions,\n",
    "    max_jobs=12,\n",
    "    max_parallel_jobs=3,\n",
    "    strategy='Bayesian'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start Tuning Job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner.fit({'train': s3_input_train_data, \n",
    "           'validation': s3_input_validation_data,\n",
    "           'test': s3_input_test_data\n",
    "          }, include_cls_metadata=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "job_description = sm.describe_hyper_parameter_tuning_job(\n",
    "    HyperParameterTuningJobName=tuner.latest_tuning_job.job_name\n",
    ")\n",
    "\n",
    "status = job_description['HyperParameterTuningJobStatus']\n",
    "\n",
    "print('\\n')\n",
    "print(status)\n",
    "print('\\n')\n",
    "pprint(job_description)\n",
    "\n",
    "if status != 'Completed':\n",
    "    job_count = job_description['TrainingJobStatusCounters']['Completed']\n",
    "    print('Not yet complete, but {} jobs have completed.')\n",
    "    \n",
    "    if job_description.get('BestTrainingJob', None):\n",
    "        print(\"Best candidate:\")\n",
    "        pprint(job_description['BestTrainingJob']['TrainingJobName'])\n",
    "        pprint(job_description['BestTrainingJob']['FinalHyperParameterTuningJobObjectiveMetric'])\n",
    "    else:\n",
    "        print(\"No training jobs have reported results yet.\")    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wait 30-60 seconds for this..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.analytics import HyperparameterTuningJobAnalytics\n",
    "\n",
    "hp_results = HyperparameterTuningJobAnalytics(\n",
    "    sagemaker_session=sess, \n",
    "    hyperparameter_tuning_job_name=tuner.latest_tuning_job.name\n",
    ")\n",
    "\n",
    "df_results = hp_results.dataframe()\n",
    "\n",
    "df_results.sort_values('FinalObjectiveValue', ascending=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results.sort_values('FinalObjectiveValue', ascending=0).head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download and Load the Trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_dir = './models'\n",
    "outputs_dir = './outputs'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the model and output artifacts from AWS S3\n",
    "!aws s3 cp $model_output_path/$training_job_name/output/model.tar.gz $models_dir/model.tar.gz\n",
    "!aws s3 cp $model_output_path/$training_job_name/output/output.tar.gz $outputs_dir/output.tar.gz\n",
    "\n",
    "#!aws s3 cp s3://sagemaker-us-east-1-835319576252/models/tf2-bert/tensorflow-training-2020-04-09-21-47-55-902/output/model.tar.gz ./models/model.tar.gz\n",
    "#!aws s3 cp s3://sagemaker-us-east-1-835319576252/models/tf2-bert/tensorflow-training-2020-04-09-21-47-55-902/output/output.tar.gz ./output/output.tar.gz\n",
    "#!aws s3 cp s3://sagemaker-us-east-1-835319576252/models/tf-bert/script-mode/training-runs/tensorflow-training-2020-03-24-04-41-39-405/output/model.tar.gz ./models/tf2-bert/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "import pickle as pkl\n",
    "\n",
    "#!ls -al ./models\n",
    "\n",
    "tar = tarfile.open('{}/model.tar.gz'.format(models_dir))\n",
    "tar.extractall(path=models_dir)\n",
    "tar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tar = tarfile.open('{}/output.tar.gz'.format(outputs_dir))\n",
    "tar.extractall(path=outputs_dir)\n",
    "tar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!ls -al $models_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!ls -al $outputs_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install TensorFlow and Transformers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install -q pip --upgrade\n",
    "!pip install -q wrapt --upgrade --ignore-installed\n",
    "!pip install -q tensorflow==2.1.0\n",
    "!pip install transformers==2.7.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_model_dir = '{}/transformer/pretrained/'.format(models_dir)\n",
    "\n",
    "!ls -al $transformer_model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat $transformer_model_dir/config.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from transformers import TFDistilBertForSequenceClassification\n",
    "\n",
    "# Class indices 0-4 map to the integer labels 1-5, i.e. label = index + 1\n",
    "# (presumably star ratings -- TODO confirm); label2id is the inverse mapping.\n",
    "loaded_model = TFDistilBertForSequenceClassification.from_pretrained(\n",
    "    transformer_model_dir,\n",
    "    id2label={class_idx: class_idx + 1 for class_idx in range(5)},\n",
    "    label2id={class_idx + 1: class_idx for class_idx in range(5)},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DistilBertTokenizer\n",
    "\n",
    "# Tokenizer is loaded by name from the stock 'distilbert-base-uncased' checkpoint.\n",
    "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
    "\n",
    "# Device index used for inference: -1 is CPU, 0 is GPU.\n",
    "inference_device = -1\n",
    "print(f'inference_device {inference_device}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TextClassificationPipeline\n",
    "\n",
    "# device: -1 is CPU, >= 0 is GPU\n",
    "inference_pipeline = TextClassificationPipeline(\n",
    "    model=loaded_model,\n",
    "    tokenizer=tokenizer,\n",
    "    framework='tf',\n",
    "    device=inference_device,\n",
    ")\n",
    "\n",
    "# Smoke-test the pipeline: print each sample review next to its prediction.\n",
    "sample_reviews = (\n",
    "    \"\"\"I loved it!  I will recommend this to everyone.\"\"\",\n",
    "    \"\"\"Really bad.  I hope they don't make this anymore.\"\"\",\n",
    ")\n",
    "for sample_review in sample_reviews:\n",
    "    print(sample_review, inference_pipeline(sample_review))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Uncomment to Start Tensorboard\n",
    "Let's start Tensorboard and point it at the Tensorboard logs that we downloaded from S3.\n",
    "\n",
    "Note:  If you pointed Tensorboard to S3 directly, you must prepend this command with `S3_REGION=[your-region]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !S3_REGION=<your-region> tensorboard --port 6006 --logdir $outputs_dir/tensorboard/ # <== MAKE SURE YOU INCLUDE THE TRAILING `/`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While Tensorboard is running locally on your SageMaker Notebook instance, it is reading the training logs from Amazon S3.\n",
    "\n",
    "Navigate to https://workshop.notebook.[your-region].sagemaker.aws/proxy/6006/  <== MAKE SURE YOU INCLUDE THE TRAILING SLASH\n",
    "\n",
    "_Note:  Make sure you copy the trailing `/` in the link above.  If you see no data, you are likely not using the correct S3 bucket above._\n",
    "\n",
    "![Tensorboard](img/tensorboard.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stop Tensorboard\n",
    "Once you are done, interrupt the kernel (`Kernel => Interrupt`, or the stop button in the toolbar) to stop the running `Tensorboard` process in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
